{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RNN Classification Models\n",
    "This example shows how to apply RNN models in deep-river, with and without an incremental class adaptation strategy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:14:09.923392Z",
     "iopub.status.busy": "2025-09-28T08:14:09.923081Z",
     "iopub.status.idle": "2025-09-28T08:14:12.468089Z",
     "shell.execute_reply": "2025-09-28T08:14:12.467519Z"
    }
   },
   "outputs": [],
   "source": [
    "from deep_river.classification import RollingClassifierInitialized\n",
    "from river import metrics, preprocessing, datasets\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RNN Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:14:12.470870Z",
     "iopub.status.busy": "2025-09-28T08:14:12.470611Z",
     "iopub.status.idle": "2025-09-28T08:14:12.475372Z",
     "shell.execute_reply": "2025-09-28T08:14:12.474764Z"
    }
   },
   "outputs": [],
   "source": [
    "class RnnModule(torch.nn.Module):\n",
    "    \"\"\"Recurrent network (torch.nn.RNN) followed by a two-unit linear head.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_features\n",
    "        Passed to the RNN as input_size.\n",
    "    hidden_size\n",
    "        Width of the recurrent hidden state and input width of the linear head.\n",
    "    num_layers\n",
    "        Number of stacked recurrent layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_features, hidden_size=16, num_layers=1):\n",
    "        super().__init__()\n",
    "        self.n_features = n_features\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "        # Sub-modules are created in the order rnn -> linear.\n",
    "        self.rnn = torch.nn.RNN(\n",
    "            input_size=n_features,\n",
    "            hidden_size=hidden_size,\n",
    "            num_layers=num_layers,\n",
    "        )\n",
    "        self.linear = torch.nn.Linear(hidden_size, 2)\n",
    "\n",
    "    def forward(self, X, **kwargs):\n",
    "        \"\"\"Return softmax-normalised scores for the two output units.\n",
    "\n",
    "        Only the final hidden state of the last recurrent layer is used;\n",
    "        the per-step outputs of the RNN are discarded.\n",
    "        \"\"\"\n",
    "        _, final_hidden = self.rnn(X)\n",
    "        top_layer_state = final_hidden[-1]  # hidden state of the last layer\n",
    "        scores = self.linear(top_layer_state)\n",
    "        return torch.softmax(scores, dim=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classification without incremental class adaptation strategy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:14:12.477918Z",
     "iopub.status.busy": "2025-09-28T08:14:12.477734Z",
     "iopub.status.idle": "2025-09-28T08:14:13.110995Z",
     "shell.execute_reply": "2025-09-28T08:14:13.110481Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><div class=\"river-component river-pipeline\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">StandardScaler</pre></summary><code class=\"river-estimator-params\">StandardScaler (\n",
       "  with_std=True\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">RollingClassifierInitialized</pre></summary><code class=\"river-estimator-params\">RollingClassifierInitialized (\n",
       "  module=RnnModule(\n",
       "  (rnn): RNN(31, 16, num_layers=2)\n",
       "  (linear): Linear(in_features=16, out_features=2, bias=True)\n",
       ")\n",
       "  loss_fn=\"binary_cross_entropy\"\n",
       "  optimizer_fn=&lt;class 'torch.optim.sgd.SGD'&gt;\n",
       "  lr=0.01\n",
       "  output_is_logit=True\n",
       "  is_class_incremental=False\n",
       "  is_feature_incremental=False\n",
       "  device=\"cpu\"\n",
       "  seed=42\n",
       "  window_size=20\n",
       "  append_predict=True\n",
       ")\n",
       "</code></details></div><style scoped>\n",
       ".river-estimator {\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "    max-width: max-content;\n",
       "}\n",
       "\n",
       ".river-pipeline {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    background: linear-gradient(#000, #000) no-repeat center / 1.5px 100%;\n",
       "}\n",
       "\n",
       ".river-union {\n",
       "    display: flex;\n",
       "    flex-direction: row;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-estimator {\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       "/* Vertical spacing between steps */\n",
       "\n",
       ".river-component + .river-component {\n",
       "    margin-top: 2em;\n",
       "}\n",
       "\n",
       ".river-union > .river-estimator {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .river-component {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .pipeline {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       "/* Spacing within a union of estimators */\n",
       "\n",
       ".river-union > .river-component + .river-component {\n",
       "    margin-left: 1em;\n",
       "}\n",
       "\n",
       "/* Typography */\n",
       "\n",
       ".river-estimator-params {\n",
       "    display: block;\n",
       "    white-space: pre-wrap;\n",
       "    font-size: 110%;\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator > .river-estimator-params,\n",
       ".river-wrapper > .river-details > river-estimator-params {\n",
       "    background-color: white !important;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-details {\n",
       "    margin-bottom: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator-name {\n",
       "    display: inline;\n",
       "    margin: 0;\n",
       "    font-size: 110%;\n",
       "}\n",
       "\n",
       "/* Toggle */\n",
       "\n",
       ".river-summary {\n",
       "    display: flex;\n",
       "    align-items:center;\n",
       "    cursor: pointer;\n",
       "}\n",
       "\n",
       ".river-summary > div {\n",
       "    width: 100%;\n",
       "}\n",
       "</style></div>"
      ],
      "text/plain": [
       "Pipeline (\n",
       "  StandardScaler (\n",
       "    with_std=True\n",
       "  ),\n",
       "  RollingClassifierInitialized (\n",
       "    module=RnnModule(\n",
       "    (rnn): RNN(31, 16, num_layers=2)\n",
       "    (linear): Linear(in_features=16, out_features=2, bias=True)\n",
       "  )\n",
       "    loss_fn=\"binary_cross_entropy\"\n",
       "    optimizer_fn=<class 'torch.optim.sgd.SGD'>\n",
       "    lr=0.01\n",
       "    output_is_logit=True\n",
       "    is_class_incremental=False\n",
       "    is_feature_incremental=False\n",
       "    device=\"cpu\"\n",
       "    seed=42\n",
       "    window_size=20\n",
       "    append_predict=True\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Data stream, evaluation metric and optimizer class used for this run.\n",
    "dataset = datasets.Keystroke()\n",
    "metric = metrics.Accuracy()\n",
    "optimizer_fn = torch.optim.SGD\n",
    "\n",
    "# Scale the features, then classify with a two-layer RNN over a rolling window.\n",
    "# NOTE(review): RnnModule.forward already applies softmax while the classifier\n",
    "# is left at output_is_logit=True (see the repr below) - confirm this pairing is intended.\n",
    "model_pipeline = preprocessing.StandardScaler() | RollingClassifierInitialized(\n",
    "    module=RnnModule(n_features=31, hidden_size=16, num_layers=2),\n",
    "    loss_fn=\"binary_cross_entropy\",\n",
    "    optimizer_fn=optimizer_fn,\n",
    "    lr=0.01,\n",
    "    window_size=20,\n",
    "    append_predict=True,\n",
    "    is_class_incremental=False,  # keep the output layer at its initial size\n",
    ")\n",
    "model_pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:14:13.113662Z",
     "iopub.status.busy": "2025-09-28T08:14:13.113415Z",
     "iopub.status.idle": "2025-09-28T08:15:14.112894Z",
     "shell.execute_reply": "2025-09-28T08:15:14.112147Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.04\n"
     ]
    }
   ],
   "source": [
    "# Test-then-train: score every sample before the pipeline learns from it.\n",
    "for features, label in dataset:\n",
    "    predicted_label = model_pipeline.predict_one(features)\n",
    "    metric.update(label, predicted_label)\n",
    "    model_pipeline.learn_one(features, label)\n",
    "\n",
    "print(f\"Accuracy: {metric.get():.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LSTM Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:15:14.115893Z",
     "iopub.status.busy": "2025-09-28T08:15:14.115623Z",
     "iopub.status.idle": "2025-09-28T08:15:14.121384Z",
     "shell.execute_reply": "2025-09-28T08:15:14.120503Z"
    }
   },
   "outputs": [],
   "source": [
    "class LSTMModule(torch.nn.Module):\n",
    "    \"\"\"Single-layer LSTM whose final hidden state feeds a two-unit linear head.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_features\n",
    "        Passed to the LSTM as input_size.\n",
    "    hidden_size\n",
    "        Width of the LSTM hidden state and input width of the linear head.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_features, hidden_size=4):\n",
    "        super().__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.n_features = n_features\n",
    "        # Sub-modules are created in the order lstm -> linear.\n",
    "        self.lstm = torch.nn.LSTM(\n",
    "            input_size=n_features,\n",
    "            hidden_size=hidden_size,\n",
    "            num_layers=1,\n",
    "        )\n",
    "        self.linear = torch.nn.Linear(hidden_size, 2)\n",
    "\n",
    "    def forward(self, X, **kwargs):\n",
    "        \"\"\"Return softmax-normalised scores for the two output units.\"\"\"\n",
    "        # The LSTM yields (output, (h_n, c_n)); only h_n is consumed here.\n",
    "        _, (final_hidden, _) = self.lstm(X)\n",
    "        flat_hidden = final_hidden.view(-1, self.hidden_size)\n",
    "        scores = self.linear(flat_hidden)\n",
    "        return torch.softmax(scores, dim=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classification without incremental class adaptation strategy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:15:14.124628Z",
     "iopub.status.busy": "2025-09-28T08:15:14.124375Z",
     "iopub.status.idle": "2025-09-28T08:15:14.132204Z",
     "shell.execute_reply": "2025-09-28T08:15:14.131388Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><div class=\"river-component river-pipeline\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">StandardScaler</pre></summary><code class=\"river-estimator-params\">StandardScaler (\n",
       "  with_std=True\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">RollingClassifierInitialized</pre></summary><code class=\"river-estimator-params\">RollingClassifierInitialized (\n",
       "  module=LSTMModule(\n",
       "  (lstm): LSTM(31, 4)\n",
       "  (linear): Linear(in_features=4, out_features=2, bias=True)\n",
       ")\n",
       "  loss_fn=\"binary_cross_entropy\"\n",
       "  optimizer_fn=&lt;class 'torch.optim.sgd.SGD'&gt;\n",
       "  lr=0.01\n",
       "  output_is_logit=True\n",
       "  is_class_incremental=False\n",
       "  is_feature_incremental=False\n",
       "  device=\"cpu\"\n",
       "  seed=42\n",
       "  window_size=20\n",
       "  append_predict=True\n",
       ")\n",
       "</code></details></div><style scoped>\n",
       ".river-estimator {\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "    max-width: max-content;\n",
       "}\n",
       "\n",
       ".river-pipeline {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    background: linear-gradient(#000, #000) no-repeat center / 1.5px 100%;\n",
       "}\n",
       "\n",
       ".river-union {\n",
       "    display: flex;\n",
       "    flex-direction: row;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-estimator {\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       "/* Vertical spacing between steps */\n",
       "\n",
       ".river-component + .river-component {\n",
       "    margin-top: 2em;\n",
       "}\n",
       "\n",
       ".river-union > .river-estimator {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .river-component {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .pipeline {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       "/* Spacing within a union of estimators */\n",
       "\n",
       ".river-union > .river-component + .river-component {\n",
       "    margin-left: 1em;\n",
       "}\n",
       "\n",
       "/* Typography */\n",
       "\n",
       ".river-estimator-params {\n",
       "    display: block;\n",
       "    white-space: pre-wrap;\n",
       "    font-size: 110%;\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator > .river-estimator-params,\n",
       ".river-wrapper > .river-details > river-estimator-params {\n",
       "    background-color: white !important;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-details {\n",
       "    margin-bottom: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator-name {\n",
       "    display: inline;\n",
       "    margin: 0;\n",
       "    font-size: 110%;\n",
       "}\n",
       "\n",
       "/* Toggle */\n",
       "\n",
       ".river-summary {\n",
       "    display: flex;\n",
       "    align-items:center;\n",
       "    cursor: pointer;\n",
       "}\n",
       "\n",
       ".river-summary > div {\n",
       "    width: 100%;\n",
       "}\n",
       "</style></div>"
      ],
      "text/plain": [
       "Pipeline (\n",
       "  StandardScaler (\n",
       "    with_std=True\n",
       "  ),\n",
       "  RollingClassifierInitialized (\n",
       "    module=LSTMModule(\n",
       "    (lstm): LSTM(31, 4)\n",
       "    (linear): Linear(in_features=4, out_features=2, bias=True)\n",
       "  )\n",
       "    loss_fn=\"binary_cross_entropy\"\n",
       "    optimizer_fn=<class 'torch.optim.sgd.SGD'>\n",
       "    lr=0.01\n",
       "    output_is_logit=True\n",
       "    is_class_incremental=False\n",
       "    is_feature_incremental=False\n",
       "    device=\"cpu\"\n",
       "    seed=42\n",
       "    window_size=20\n",
       "    append_predict=True\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fresh data stream and metric so this run is scored independently.\n",
    "dataset = datasets.Keystroke()\n",
    "metric = metrics.Accuracy()\n",
    "optimizer_fn = torch.optim.SGD\n",
    "\n",
    "# Scale the features, then classify with a single-layer LSTM over a rolling window.\n",
    "# NOTE(review): LSTMModule.forward already applies softmax while the classifier\n",
    "# is left at output_is_logit=True (see the repr below) - confirm this pairing is intended.\n",
    "model_pipeline = preprocessing.StandardScaler() | RollingClassifierInitialized(\n",
    "    module=LSTMModule(n_features=31, hidden_size=4),\n",
    "    loss_fn=\"binary_cross_entropy\",\n",
    "    optimizer_fn=optimizer_fn,\n",
    "    lr=0.01,\n",
    "    window_size=20,\n",
    "    append_predict=True,\n",
    "    is_class_incremental=False,  # same as the default shown in the repr below\n",
    ")\n",
    "model_pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:15:14.135048Z",
     "iopub.status.busy": "2025-09-28T08:15:14.134762Z",
     "iopub.status.idle": "2025-09-28T08:15:51.773412Z",
     "shell.execute_reply": "2025-09-28T08:15:51.772582Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.03\n"
     ]
    }
   ],
   "source": [
    "# Test-then-train: score every sample before the pipeline learns from it.\n",
    "for features, label in dataset:\n",
    "    predicted_label = model_pipeline.predict_one(features)\n",
    "    metric.update(label, predicted_label)\n",
    "    model_pipeline.learn_one(features, label)\n",
    "\n",
    "print(f\"Accuracy: {metric.get():.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deep-river",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
